import shutil
import tempfile
import pybedtools
from numpy import *
from scipy.stats import spearmanr, pearsonr
from pylab import *

from rpy2 import robjects, rinterface
import rpy2.robjects.numpy2ri
robjects.numpy2ri.activate()
from rpy2.robjects.packages import importr
deseq = importr('DESeq2')
glmGamPoi = importr('glmGamPoi')


def calculate_dispersion(counts1, counts2):
    """Fit one shared overdispersion to two paired count vectors.

    The two vectors are stacked as the two columns of an integer matrix
    (one row per element) and passed to glmGamPoi's ``glm_gp`` with
    ``overdispersion='global'``.

    Parameters
    ----------
    counts1, counts2 : sequences of equal length
        Per-row counts for the first and second sample; values are stored
        into an integer matrix.

    Returns
    -------
    The single distinct value found in the fit's ``overdispersions`` entry.

    Raises
    ------
    AssertionError
        If the lengths differ, or the fit does not yield exactly one
        distinct overdispersion value.
    """
    n_rows = len(counts1)
    assert len(counts2) == n_rows
    matrix = zeros((n_rows, 2), int)
    for column, values in enumerate((counts1, counts2)):
        matrix[:, column] = values
    fit = glmGamPoi.glm_gp(matrix, overdispersion='global')
    extracted = fit.rx['overdispersions']
    assert len(extracted) == 1
    # A global fit should repeat the same value for every row.
    distinct = set(extracted[0])
    assert len(distinct) == 1
    (overdispersion,) = distinct
    return overdispersion

# Map every peak, keyed as "<chrom>_<start>-<end>_<strand>", to the category
# stored in the name column of the BED file.
filename = "peaks.MiSeq_HiSeq.bed"
categories = {
    "%s_%d-%d_%s" % (feature.chrom, feature.start, feature.end, feature.strand): feature.name
    for feature in pybedtools.BedTool(filename)
}


# Sample design of the experiment.
datasets = ("MiSeq", "HiSeq")
timepoints = (0, 1, 4, 12, 24, 96)
replicates = (1, 2, 3)

# Fine-grained BED categories that are pooled into the "sense" group.
sense_categories = {"sense_proximal",
                    "sense_upstream",
                    "sense_distal",
                    "sense_distal_upstream"}

filename = "peaks.MiSeq_HiSeq.expression.txt"
print("Reading", filename)
peaks = []
counts = []
category_indices = {"Pol-II short RNA": [],
                    "sense": [],
                    "other": [],
                   }
# Using a context manager closes the file even when parsing fails part-way,
# e.g. when a peak is missing from `categories` (KeyError).
with open(filename) as handle:
    # Header: "peak" followed by one column name per library.
    words = next(handle).split()
    assert words[0] == 'peak'
    libraries = words[1:]
    for line in handle:
        words = line.split()
        if not words:
            # Tolerate blank lines (e.g. trailing empty lines) instead of
            # failing with IndexError on words[0].
            continue
        peak = words[0]
        row = array(words[1:], int)
        category = categories[peak]
        if category in sense_categories:
            category = "sense"
        elif category != "Pol-II short RNA":
            category = "other"
        # Record the row position the peak will have in `peaks`/`counts`, so
        # indices stay aligned even when blank lines were skipped.
        category_indices[category].append(len(peaks))
        peaks.append(peak)
        counts.append(row)

# Convert to integer index arrays; an explicit int dtype keeps an empty
# category usable for indexing (array([]) would default to float).
for category in category_indices:
    category_indices[category] = array(category_indices[category], int)

# Peak that is highlighted in the concordance scatter plots.
dbi_peak = "chr2_119366907-119367060_+"

counts = array(counts)
totals = sum(counts, 0)  # column sums: total tag count per library

# indices[dataset][timepoint][replicate] -> column of `counts` for that library.
indices = {dataset: {timepoint: {} for timepoint in timepoints}
           for dataset in datasets}

for index, (library, total) in enumerate(zip(libraries, totals)):
    # Library names are expected as "<dataset>_t<hours>_r<replicate>".
    terms = library.split("_")
    assert len(terms) == 3
    dataset, time_label, replicate_label = terms
    assert dataset in datasets
    assert time_label.startswith("t")
    timepoint = int(time_label[1:])
    assert timepoint in timepoints
    assert replicate_label.startswith("r")
    replicate = int(replicate_label[1:])
    assert replicate in replicates
    indices[dataset][timepoint][replicate] = index
    print("Including %s time point %s replicate %d with total count %d" % (dataset, timepoint, replicate, total))


# Figure 1: a grid of log-log scatter plots, one panel per (time point,
# replicate) pair, with MiSeq tag counts on x and HiSeq tag counts on y.
fig = figure(figsize=(6, 12))

# Full-figure axes with spines, ticks and tick labels hidden; it only
# carries the shared x/y axis labels for the whole grid.
ax = fig.add_subplot(111)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
ax.set_xlabel("MiSeq time course samples, paired-end sequencing, tag count")
ax.set_ylabel("HiSeq time course samples, single-end sequencing, tag count", labelpad=27)

# Row of the highlighted peak in `counts`, looked up once instead of
# rebuilding and searching a filtered peak list for every panel.
if dbi_peak in peaks:
    dbi_row = peaks.index(dbi_peak)
else:
    dbi_row = None
    print("Warning: highlighted peak %s not found; it will not be marked" % dbi_peak)

dispersions = {}        # (timepoint, replicate) -> fitted global overdispersion
correlations = []       # Pearson r after log(count + 1), one per panel
correlations_raw = []   # Pearson r on raw counts, one per panel
for i, timepoint in enumerate(timepoints):
    for j, replicate in enumerate(replicates):
        index_miseq = indices["MiSeq"][timepoint].get(replicate)
        index_hiseq = indices["HiSeq"][timepoint].get(replicate)
        # Only compare pairs for which both platforms have this library.
        if index_miseq is None or index_hiseq is None:
            continue
        column_miseq = counts[:, index_miseq]
        column_hiseq = counts[:, index_hiseq]
        # Restrict to peaks with at least one tag on either platform.
        row_flags = (column_miseq > 0) | (column_hiseq > 0)
        column_miseq = column_miseq[row_flags]
        column_hiseq = column_hiseq[row_flags]
        print("Time point %d hour, replicate %d:" % (timepoint, replicate))
        dispersion = calculate_dispersion(column_miseq, column_hiseq)
        print("Dispersion", dispersion)
        dispersions[(timepoint, replicate)] = dispersion
        correlation, _ = spearmanr(column_miseq, column_hiseq)
        print("Spearman", correlation)
        correlation, _ = pearsonr(column_miseq, column_hiseq)
        print("Pearson, before log transform", correlation)
        correlations_raw.append(correlation)
        correlation, _ = pearsonr(log(column_miseq+1), log(column_hiseq+1))
        print("Pearson, after log transform", correlation)
        correlations.append(correlation)
        fig.add_subplot(len(timepoints), len(replicates), i * len(replicates) + j + 1)
        loglog(column_miseq, column_hiseq, "ko", markersize=1)
        # Overlay the highlighted peak in blue. When it has no tags in this
        # pair it is simply not marked, instead of aborting the script.
        if dbi_row is not None and row_flags[dbi_row]:
            loglog(counts[dbi_row:dbi_row+1, index_miseq],
                   counts[dbi_row:dbi_row+1, index_hiseq],
                   "bo", markersize=5)
        # Tick labels only on the bottom row and the left column.
        if i == len(timepoints) - 1:
            xticks(fontsize=8)
        else:
            xticks([])
        xlim(0.9, 1e4)
        ylim(0.9, 1e7)
        if j == 0:
            ylabel("%d hour" % timepoint, fontsize=8)
            yticks(fontsize=8)
        else:
            yticks([])
        if i == 0:
            title("Replicate %d\ndispersion = %.2f" % (replicate, dispersion), fontsize=8, pad=0)
        else:
            title("dispersion = %.2f" % dispersion, fontsize=8, pad=0)

# Summary statistics averaged over all plotted panels.
print("Mean Pearson correlation, without log transformation: %.4f" % mean(correlations_raw))
print("Mean Pearson correlation, after log transformation: %.4f" % mean(correlations))
mean_dispersion = mean(list(dispersions.values()))
print("Mean dispersion: %.4f" % mean_dispersion)

subplots_adjust(left=0.15, right=0.99, bottom=0.07, top=0.95, hspace=0.3, wspace=0.1)

# Save the concordance figure in both vector and raster formats.
for filename in ("figure_miseq_hiseq_concordance.svg",
                 "figure_miseq_hiseq_concordance.png"):
    print("Saving figure to %s" % filename)
    savefig(filename)

# Figure 2: bar chart of the fitted dispersion per (time point, replicate),
# with bars shaded by time point on the Blues colormap.
samples = list(dispersions.items())
colors = [cm.Blues(timepoints.index(sample_time) / 10)
          for (sample_time, _), _ in samples]
values = [sample_value for _, sample_value in samples]
labels = ["%d hour, replicate %d" % sample_key for sample_key, _ in samples]

figure()
values = array(values)
x = arange(len(values))
bar(x, values, color=colors, edgecolor='black')
xticks(x, labels, fontsize=8, rotation=90)
yticks(fontsize=8)
xlabel("Sample", labelpad=20, fontsize=8)
ylabel("Dispersion", fontsize=8)
subplots_adjust(bottom=0.4)

# Save the dispersion figure in both vector and raster formats.
for filename in ("figure_miseq_hiseq_dispersion.svg",
                 "figure_miseq_hiseq_dispersion.png"):
    print("Saving figure to %s" % filename)
    savefig(filename)
